import torch.nn as nn
import torch
import copy
import math
import torch.nn.functional as F
from typing import Any, DefaultDict, Dict, List, Optional, Tuple

class LoraLinear(nn.Module):
    """Wrap an ``nn.Linear`` with a trainable low-rank (LoRA) update.

    The wrapped layer is deep-copied and frozen. The forward pass returns
    ``base_layer(x) + (alpha / r) * B(A(dropout(x)))`` where ``lora_A`` has
    shape ``(r, in_features)`` and ``lora_B`` has shape ``(out_features, r)``.
    A rank of ``0`` disables the LoRA branch, so the module behaves exactly
    like the frozen base layer.
    """

    def __init__(
        self,
        base_layer: nn.Linear,      # original linear layer (deep-copied, then frozen)
        r: int = 8,                 # LoRA rank
        alpha: int = 16,            # LoRA alpha (numerator of the scaling factor)
        dropout_p: float = 0.0,     # dropout applied to the input of the LoRA branch
        test_mode: bool = False,    # test mode: initialise lora_B randomly instead of all zeros
    ):
        super(LoraLinear, self).__init__()
        self.base_layer = copy.deepcopy(base_layer)
        self.r = r
        self.alpha = alpha
        self.dropout = nn.Dropout(dropout_p)

        # Create the LoRA matrices with the same dtype AND device as the
        # wrapped weight; otherwise wrapping a layer that already lives on an
        # accelerator leaves the adapter on the default device and the forward
        # pass fails with a device mismatch.
        factory_kwargs = {
            "dtype": base_layer.weight.dtype,
            "device": base_layer.weight.device,
        }
        self.lora_A = nn.Parameter(torch.empty((r, base_layer.in_features), **factory_kwargs))
        self.lora_B = nn.Parameter(torch.empty((base_layer.out_features, r), **factory_kwargs))

        # Initialise the LoRA matrices. With lora_B at zero the wrapper starts
        # out numerically identical to the base layer.
        nn.init.normal_(self.lora_A, mean=0.0, std=0.02)
        if test_mode:
            nn.init.normal_(self.lora_B, mean=0.0, std=0.02)
        else:
            nn.init.zeros_(self.lora_B)

        # Freeze the parameters of the wrapped layer.
        for param in self.base_layer.parameters():
            param.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the base layer output plus the scaled low-rank adjustment."""
        base_output = self.base_layer(x)
        if self.r == 0:
            # Rank 0 means "no adapter"; alpha / r would divide by zero.
            return base_output
        scaling = float(self.alpha) / float(self.r)     # LoRA scaling factor
        lora_adjustment = F.linear(self.dropout(x), self.lora_A)
        lora_adjustment = F.linear(lora_adjustment, self.lora_B)
        return base_output + lora_adjustment * scaling

class PlainMultiheadAttentionLoRA(nn.Module):
    """Wrap an ``nn.MultiheadAttention`` with LoRA updates on its projections.

    One low-rank pair (``lora_A_*``, ``lora_B_*``) is kept for each of the
    query, key, value and output projections. The wrapped layer is deep-copied
    and frozen. On every forward pass the scaled products
    ``(alpha / r) * B @ A`` are added to the frozen projection weights and the
    result is passed to ``F.multi_head_attention_forward``.

    The fused native attention kernel is never used; every call goes through
    ``F.multi_head_attention_forward`` with the LoRA-adjusted weights. A rank
    of ``0`` (the default) disables the adapter, so the module computes the
    same attention as the frozen base layer.
    """

    def __init__(
        self,
        base_layer: nn.MultiheadAttention,
        r: int = 0,
        lora_alpha: int = 1,
        dropout_rate: float = 0.,
        test_mode: bool = False,
    ):
        """Build the wrapper.

        Args:
            base_layer: attention layer to wrap (deep-copied, then frozen).
            r: LoRA rank; ``0`` disables the low-rank update.
            lora_alpha: numerator of the ``alpha / r`` scaling factor.
            dropout_rate: accepted for signature parity; no LoRA-specific
                dropout is applied and ``self.dropout`` stays ``0``.
            test_mode: initialise the ``lora_B_*`` matrices randomly instead
                of with zeros.
        """
        super().__init__()
        self.base_layer = copy.deepcopy(base_layer)

        self.dropout = 0
        self.embed_dim = self.base_layer.embed_dim
        self.batch_first = self.base_layer.batch_first
        self.alpha = lora_alpha
        self.r = r
        self._qkv_same_embed_dim = self.base_layer._qkv_same_embed_dim

        # out_proj always exists, whereas in_proj_weight is None when the
        # key/value dimensions differ from embed_dim, so dtype and device are
        # taken from out_proj.
        out_proj = base_layer.out_proj
        factory_kwargs = {
            "dtype": out_proj.weight.dtype,
            "device": out_proj.weight.device,
        }

        def new_param(rows: int, cols: int) -> nn.Parameter:
            return nn.Parameter(torch.empty((rows, cols), **factory_kwargs))

        # Key / value inputs may have their own feature sizes (kdim / vdim);
        # both equal embed_dim when _qkv_same_embed_dim is True.
        self.lora_A_q = new_param(r, self.embed_dim)
        self.lora_B_q = new_param(self.embed_dim, r)
        self.lora_A_k = new_param(r, base_layer.kdim)
        self.lora_B_k = new_param(self.embed_dim, r)
        self.lora_A_v = new_param(r, base_layer.vdim)
        self.lora_B_v = new_param(self.embed_dim, r)
        self.lora_A_o = new_param(r, out_proj.in_features)
        self.lora_B_o = new_param(out_proj.out_features, r)

        if r > 0:  # rank-0 matrices have no elements to initialise
            for lora_A in (self.lora_A_q, self.lora_A_k, self.lora_A_v, self.lora_A_o):
                nn.init.kaiming_uniform_(lora_A, a=math.sqrt(5))
            for lora_B in (self.lora_B_q, self.lora_B_k, self.lora_B_v, self.lora_B_o):
                if test_mode:
                    nn.init.kaiming_uniform_(lora_B, a=math.sqrt(5))
                else:
                    nn.init.zeros_(lora_B)

        # Freeze the parameters of the wrapped layer.
        for param in self.base_layer.parameters():
            param.requires_grad = False

    def _merged(self, weight: torch.Tensor, lora_B: torch.Tensor, lora_A: torch.Tensor) -> torch.Tensor:
        """Return ``weight`` plus the scaled low-rank update ``B @ A``."""
        if self.r == 0:
            # No adapter; alpha / r would divide by zero.
            return weight
        return weight + torch.mm(lora_B, lora_A) * (self.alpha / self.r)

    def forward_module(
            self,
            query,
            key,
            value,
            key_padding_mask=None,
            need_weights=True,
            attn_mask=None,
            average_attn_weights=True,
            is_causal=False):
        """Run multi-head attention with the LoRA-adjusted projection weights.

        Args:
            query, key, value: attention inputs; batched inputs are
                batch-first when the wrapped layer was built with
                ``batch_first=True``.
            key_padding_mask: optional bool or floating-point mask.
            need_weights: whether attention weights are returned.
            attn_mask: optional attention mask.
            average_attn_weights: forwarded to
                ``F.multi_head_attention_forward``.
            is_causal: forwarded to ``F.multi_head_attention_forward`` only
                when true.

        Returns:
            ``(attn_output, attn_output_weights)`` as produced by
            ``F.multi_head_attention_forward``; ``attn_output`` is transposed
            back to batch-first when applicable.

        Raises:
            AssertionError: for an unsupported ``key_padding_mask`` dtype or
                nested-tensor inputs.
        """
        is_batched = query.dim() == 3
        if key_padding_mask is not None:
            _kpm_dtype = key_padding_mask.dtype
            if _kpm_dtype != torch.bool and not torch.is_floating_point(key_padding_mask):
                raise AssertionError(
                    "only bool and floating types of key_padding_mask are supported")

        if query.is_nested or key.is_nested or value.is_nested:
            raise AssertionError(
                "MultiheadAttention does not support NestedTensor outside of its fast path. "
                "The LoRA wrapper never uses the fast path.")

        if self.batch_first and is_batched:
            # F.multi_head_attention_forward expects sequence-first inputs.
            # Preserve tensor identity between query / key / value.
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = [x.transpose(1, 0) for x in (query, key)]
                    value = key
            else:
                query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

        base = self.base_layer
        if self._qkv_same_embed_dim:
            q_base, k_base, v_base = base.in_proj_weight.chunk(3)
        else:
            q_base, k_base, v_base = base.q_proj_weight, base.k_proj_weight, base.v_proj_weight

        q_weight = self._merged(q_base, self.lora_B_q, self.lora_A_q)
        k_weight = self._merged(k_base, self.lora_B_k, self.lora_A_k)
        v_weight = self._merged(v_base, self.lora_B_v, self.lora_A_v)
        o_weight = self._merged(base.out_proj.weight, self.lora_B_o, self.lora_A_o)

        extra_kwargs = {}
        if self._qkv_same_embed_dim:
            in_proj_weight = torch.cat([q_weight, k_weight, v_weight])
        else:
            # Separate projections: the packed in_proj_weight does not exist.
            in_proj_weight = None
            extra_kwargs.update(
                use_separate_proj_weight=True,
                q_proj_weight=q_weight,
                k_proj_weight=k_weight,
                v_proj_weight=v_weight,
            )
        if is_causal:
            # Only passed when requested so the call keeps working on torch
            # versions whose functional API predates this keyword.
            extra_kwargs["is_causal"] = True

        attn_output, attn_output_weights = F.multi_head_attention_forward(
            query, key, value, base.embed_dim, base.num_heads,
            in_proj_weight, base.in_proj_bias,
            base.bias_k, base.bias_v, base.add_zero_attn,
            base.dropout, o_weight, base.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask, need_weights=need_weights,
            attn_mask=attn_mask,
            average_attn_weights=average_attn_weights,
            **extra_kwargs)

        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights

    def forward(self,
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            **kwargs):
        """Delegate to :meth:`forward_module`, passing keyword options through."""
        return self.forward_module(query, key, value, **kwargs)
    
def replace_linear_with_lora(
    module: nn.Module,
    r: int = 8,
    alpha: int = 16,
    dropout_p: float = 0.0,
    embed_requires_grad: bool = False,      # whether embedding layers stay trainable
    norm_requires_grad: bool = False,       # whether norm layers stay trainable
    head_requires_grad: bool = False,       # whether the lm_head stays trainable (causal LMs only)
    test_mode: bool = False,                # test mode: controls whether lora_B starts as all zeros
):
    """
    Recursively swap the linear and multi-head attention layers inside
    ``module`` for their LoRA wrappers.

    Children whose name contains 'embed', 'norm' or 'lm_head' are not wrapped
    and not descended into; only the ``requires_grad`` flag of their
    parameters is set from the matching ``*_requires_grad`` argument.
    """
    # Checked in this order; the first keyword found in the child name wins.
    special_layers = (
        ('embed', embed_requires_grad),
        ('norm', norm_requires_grad),
        ('lm_head', head_requires_grad),
    )

    for child_name, child in module.named_children():
        # Special layers come first: lm_head is itself a linear layer and
        # must not be wrapped.
        matched = [flag for keyword, flag in special_layers if keyword in child_name]
        if matched:
            trainable = matched[0]
            for param in child.parameters():
                param.requires_grad = trainable
            continue

        # Wrap every linear layer (the QLoRA approach).
        if isinstance(child, nn.Linear):
            setattr(
                module,
                child_name,
                LoraLinear(child, r=r, alpha=alpha, dropout_p=dropout_p, test_mode=test_mode),
            )
            continue

        if isinstance(child, nn.MultiheadAttention):
            setattr(
                module,
                child_name,
                PlainMultiheadAttentionLoRA(
                    child, r=r, lora_alpha=alpha, dropout_rate=dropout_p, test_mode=test_mode
                ),
            )
            continue

        # Anything else: keep descending.
        replace_linear_with_lora(
            child, r, alpha, dropout_p,
            embed_requires_grad, norm_requires_grad, head_requires_grad,
            test_mode=test_mode
        )

def print_trainable_parameters(model: nn.Module):
    """
    Print the number of trainable parameters, the total number of parameters
    and the trainable share as a percentage, in a format similar to
    PeftModel's print_trainable_parameters method.

    A model without any parameters is reported as 0% trainable instead of
    raising ZeroDivisionError.
    """
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count

    trainable_percentage = 100 * trainable_params / total_params if total_params else 0.0

    # Trainable parameter count, total parameter count, trainable share (percent).
    print(f"trainable params: {trainable_params:,} || all params: {total_params:,} || trainable%: {trainable_percentage:.4f}")

def unload_lora(module: nn.Module, adapter_name: str = 'adapter'):
    """
    Detach the LoRA weights from ``module`` and put the wrapped linear layers
    back in place.

    Every ``LoraLinear`` found below ``module`` is replaced by its
    ``base_layer``. Its LoRA matrices (copied to CPU) and hyper-parameters
    are collected under the layer's dotted name and written to
    ``f"{adapter_name}.pt"``. Afterwards every parameter of ``module`` is
    marked trainable.
    """
    collected = {}

    def strip_lora(parent: nn.Module, path: List[str]):
        for child_name, child in parent.named_children():
            child_path = path + [child_name]
            if not isinstance(child, LoraLinear):
                strip_lora(child, child_path)
                continue

            # Remember the adapter, then restore the plain layer.
            entry = {}
            entry["lora_A_weight"] = child.lora_A.data.cpu()
            entry["lora_B_weight"] = child.lora_B.data.cpu()
            entry["r"] = child.r
            entry["alpha"] = child.alpha
            entry["dropout_p"] = child.dropout.p
            collected['.'.join(child_path)] = entry
            setattr(parent, child_name, child.base_layer)

    strip_lora(module, [])

    # Unfreeze the whole model.
    for param in module.parameters():
        param.requires_grad = True

    torch.save(collected, f"{adapter_name}.pt")

def load_lora(module: nn.Module, adapter_name: str = 'adapter'):
    """
    Load LoRA weights from ``f"{adapter_name}.pt"`` and re-attach them.

    For every saved entry whose dotted name resolves to an ``nn.Linear``
    inside ``module``, that layer is replaced by a ``LoraLinear`` carrying the
    saved matrices, placed on the device and dtype of the wrapped weight.
    Entries that resolve to any other kind of object are skipped. Finally,
    parameters whose name contains 'embed', 'norm' or 'lm_head' are frozen.

    Raises:
        KeyError: if a saved name cannot be resolved inside ``module``.
    """
    # NOTE: torch.load unpickles the file; only load adapters from trusted
    # sources. Tensors are mapped to CPU first and moved to the target layer's
    # device below.
    lora_parameters = torch.load(f"{adapter_name}.pt", map_location="cpu")

    for name, lora_params in lora_parameters.items():
        # Example name: layers.0.self_attn.q_proj
        # Walk the dotted path once to find both the parent and the layer,
        # instead of rebuilding the full name -> module mapping per entry.
        parts = name.split(".")
        try:
            parent = module
            for part in parts[:-1]:  # everything except the last level
                parent = getattr(parent, part)
            child = getattr(parent, parts[-1])
        except AttributeError:
            raise KeyError(name) from None

        if not isinstance(child, nn.Linear):
            continue

        lora_linear = LoraLinear(child, lora_params['r'], lora_params['alpha'], lora_params['dropout_p'])
        # Follow the wrapped weight, not the freshly created LoRA parameter:
        # the latter may sit on a different device than the base layer.
        base_weight = lora_linear.base_layer.weight
        lora_linear.lora_A.data = lora_params["lora_A_weight"].to(
            device=base_weight.device, dtype=base_weight.dtype)
        lora_linear.lora_B.data = lora_params["lora_B_weight"].to(
            device=base_weight.device, dtype=base_weight.dtype)

        setattr(parent, parts[-1], lora_linear)

    # Restore the freezing scheme in a simplified way: embedding, norm and
    # lm_head parameters are frozen (wrapped base layers are frozen by
    # LoraLinear itself).
    for name, param in module.named_parameters():
        if any(s in name for s in ['embed', 'norm', 'lm_head']):
            param.requires_grad = False

def merge_lora(module: nn.Module):
    """
    Fold the LoRA weights into their base layers and put the plain
    ``nn.Linear`` layers back in place.

    For every ``LoraLinear`` below ``module`` the scaled product
    ``(alpha / r) * lora_B @ lora_A`` is added in place to the wrapped weight,
    and the wrapper is replaced by its ``base_layer``. Afterwards every
    parameter of ``module`` is marked trainable.
    """
    def merge_children(parent: nn.Module):
        for name, child in parent.named_children():
            if not isinstance(child, LoraLinear):
                merge_children(child)
                continue

            # A rank-0 adapter contributes nothing, and alpha / 0 would raise.
            if child.r:
                weight = child.base_layer.weight
                with torch.no_grad():
                    lora_adjustment = torch.matmul(child.lora_B, child.lora_A) * (child.alpha / child.r)
                    # Align with the base weight so the in-place add also works
                    # when the adapter sits on another device or dtype.
                    weight.add_(lora_adjustment.to(device=weight.device, dtype=weight.dtype))

            # Swap the wrapper for the (now merged) base layer.
            setattr(parent, name, child.base_layer)

    merge_children(module)

    # Unfreeze the whole model.
    for param in module.parameters():
        param.requires_grad = True